-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add DETA #20983
Add DETA #20983
Conversation
cc @alaradirik this PR is in a ready state, except for 2 things:
|
There is no problem with the model requiring torchvision to be installed. We have many models with specific dependencies, some of which you ported yourself ;-). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this! I added a few comments but it looks good to me overall.
Could you add an nms_threshold
argument to the object detection post processing? I can do a follow up PR to add it to all other post processing methods. Or it can be left as it is and I can add it shortly and update the object detection pipeline as well
is_level_ordered = ( | ||
level_ids[keep_inds][None] | ||
== torch.arange(len(spatial_shapes), device=level_ids.device)[:, None] | ||
) # LS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we make the comment more descriptive?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pinging the authors here, @xingyizhou @jozhang97, could you clarify what LS means here?
Args: | ||
outputs ([`DetrObjectDetectionOutput`]): | ||
Raw outputs of the model. | ||
threshold (`float`, *optional*): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a nms_threshold
argument and set it to 0.7 by default? We can leave DETA out of object detection mapping for now and I can do a followup PR shortly to add NMS support to all post_process_object_detection methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't just add NMS support to all post_process_object_detection methods I'm afraid, since for that one uses Torchvision's NMS op which is written in C and much faster than plain python. Also our existing models don't need NMS, DETA is actually the first one that needs it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NMS is a quite common to post-processing method for object detection and useful for models with noisy bounding box proposals. Given two bounding boxes with high prediction scores, they might be detecting the same object instance and NMS eliminates the redundant boxes.
I think it'd be nice to have an option to perform NMS but I'm fine with not adding it as well.
CC'ing @amyeroberts and @sgugger
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since for that one uses Torchvision's NMS op which is written in C and much faster than plain python
@NielsRogge I'm not sure I completely follow - is the issue here speed or the use of torchvision?
In general, I agree with @alaradirik, NMS is common enough that it's something I think we want to support, even if it's only for our PyTorch models at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok we can add NMS as an option, this way the API of all post_process_object_detection
methods will be the same. I assume we'll want to leverage torchvision for this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it makes sense to leverage torchvision for this
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this new model. Left a couple of comments.
d64a408
to
ef7ff2b
Compare
@sgugger I've addressed all comments, except for adding support for the custom kernel. Could we perhaps add support for the custom kernel for the 3 models (Mask2Former, OneFormer and DETA) in a separate PR? |
In this case, remove the code trying to load the custom kernels in the modeling file and we can add it back in the PR that will deal with custom kernels. |
@sgugger ok, feel free to approve :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all your work on this!
7d97248
to
2d8d824
Compare
Failing test is unrelated/flaky, merging. |
* First draft * Add initial draft of conversion script * Convert all weights * Fix config * Add image processor * Fix DetaImageProcessor * Run make fix copies * Remove timm dependency * Fix dummy objects * Improve loss function * Remove conv_encoder attribute * Update conversion scripts * Improve postprocessing + docs * Fix copied from statements * Add tests * Improve postprocessing * Improve postprocessing * Update READMEs * More improvements * Fix rebase * Add is_torchvision_available * Add torchvision dependency * Fix typo and README * Fix bug * Add copied from * Fix style * Apply suggestions * Fix thanks to @ydshieh * Fix another dependency check * Simplify image processor * Add scipy * Improve code * Add threshold argument * Fix bug * Set default threshold * Improve integration test * Add another integration test * Update setup.py * Address review * Improve deformable attention function * Improve copied from * Use relative imports * Address review * Replace assertions * Address review * Update dummies * Remove dummies * Address comments, update READMEs * Remove custom kernel code * Add image processor tests * Add requires_backends * Add minor comment * Update scripts * Update organization name * Fix defaults, add doc tests * Add id2label for object 365 * Fix tests * Update task guide
* First draft * Add initial draft of conversion script * Convert all weights * Fix config * Add image processor * Fix DetaImageProcessor * Run make fix copies * Remove timm dependency * Fix dummy objects * Improve loss function * Remove conv_encoder attribute * Update conversion scripts * Improve postprocessing + docs * Fix copied from statements * Add tests * Improve postprocessing * Improve postprocessing * Update READMEs * More improvements * Fix rebase * Add is_torchvision_available * Add torchvision dependency * Fix typo and README * Fix bug * Add copied from * Fix style * Apply suggestions * Fix thanks to @ydshieh * Fix another dependency check * Simplify image processor * Add scipy * Improve code * Add threshold argument * Fix bug * Set default threshold * Improve integration test * Add another integration test * Update setup.py * Address review * Improve deformable attention function * Improve copied from * Use relative imports * Address review * Replace assertions * Address review * Update dummies * Remove dummies * Address comments, update READMEs * Remove custom kernel code * Add image processor tests * Add requires_backends * Add minor comment * Update scripts * Update organization name * Fix defaults, add doc tests * Add id2label for object 365 * Fix tests * Update task guide
What does this PR do?
This PR adds DETA. DETA is a slight change to Deformable DETR by using traditional IoU-based assignment as opposed to the Hungarian matching used in the original DETR, and incorporating NMS (non-maximum suppression) in the postprocessing.
Note: this model has a
torchvision
dependency for NMS.To do: